{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting Acute Toxicity LD50\n",
    "\n",
    "Here we show a worked example applying graphein to process a [molecule dataset](https://tdcommons.ai/single_pred_tasks/tox/) from [TDC](https://tdcommons.ai/).\n",
    "\n",
    "**Dataset Description**: Acute toxicity LD50 measures the most conservative dose that can lead to lethal adverse effects. The higher the dose, the more lethal of a drug. This dataset is kindly provided by the authors of [1].\n",
    "\n",
    "**Task Description**: Regression. Given a drug SMILES string, predict its acute toxicity.\n",
    "\n",
    "**Dataset Statistics**: 7,385 drugs.\n",
    "\n",
    "**Dataset Split**: Random Split Scaffold Split\n",
    "\n",
    "[1] Zhu, Hao, et al. “Quantitative structure− activity relationship modeling of rat acute toxicity by oral exposure.” Chemical research in toxicology 22.12 (2009): 1913-1921.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/molecule_model_tutorial_tox.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/a-r-j/graphein/blob/master/notebooks/molecule_model_tutorial_tox.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Graphein if necessary\n",
    "# !pip install graphein[extras]\n",
    "\n",
    "# Install TDC if necessary\n",
    "# !pip install PyTDC\n",
    "\n",
    "# NB you may need to install DL libraries such as pytorch, pytorch-lightning and torch-geometric\n",
    "# These are left to the user to configure as they depend on your particular desired configuration (e.g CUDA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "from tdc.single_pred import Tox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found local copy...\n",
      "Loading...\n",
      "Done!\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Drug_ID</th>\n",
       "      <th>Drug</th>\n",
       "      <th>Y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Methane, tribromo-</td>\n",
       "      <td>BrC(Br)Br</td>\n",
       "      <td>2.343</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Bromoethene (9CI)</td>\n",
       "      <td>C=CBr</td>\n",
       "      <td>2.330</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1,1'-Biphenyl, hexabromo-</td>\n",
       "      <td>Brc1ccc(-c2ccc(Br)c(Br)c2Br)c(Br)c1Br</td>\n",
       "      <td>1.465</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Isothiocyanic acid, p-bromophenyl ester</td>\n",
       "      <td>S=C=Nc1ccc(Br)cc1</td>\n",
       "      <td>2.729</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Benzene, bromo-</td>\n",
       "      <td>Brc1ccccc1</td>\n",
       "      <td>1.765</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   Drug_ID  \\\n",
       "0                       Methane, tribromo-   \n",
       "1                        Bromoethene (9CI)   \n",
       "2                1,1'-Biphenyl, hexabromo-   \n",
       "3  Isothiocyanic acid, p-bromophenyl ester   \n",
       "4                          Benzene, bromo-   \n",
       "\n",
       "                                    Drug      Y  \n",
       "0                              BrC(Br)Br  2.343  \n",
       "1                                  C=CBr  2.330  \n",
       "2  Brc1ccc(-c2ccc(Br)c(Br)c2Br)c(Br)c1Br  1.465  \n",
       "3                      S=C=Nc1ccc(Br)cc1  2.729  \n",
       "4                             Brc1ccccc1  1.765  "
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NBVAL_SKIP\n",
    "# Load data\n",
    "data = Tox(name = 'LD50_Zhu')\n",
    "split = data.get_split()\n",
    "split[\"train\"].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Molecular Graphs with Graphein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "import torch\n",
    "import graphein.molecule as gm\n",
    "import graphein.ml as ml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                   \r"
     ]
    }
   ],
   "source": [
    "# NBVAL_SKIP\n",
    "config = gm.MoleculeGraphConfig()\n",
    "\n",
    "# Iterate over dataframes containing each split\n",
    "train_graphs = [gm.construct_graph(smiles=smiles, config=config) for smiles in split[\"train\"][\"Drug\"]]\n",
    "valid_graphs = [gm.construct_graph(smiles=smiles, config=config) for smiles in split[\"valid\"][\"Drug\"]]\n",
    "test_graphs = [gm.construct_graph(smiles=smiles, config=config) for smiles in split[\"test\"][\"Drug\"]]\n",
    "\n",
    "# Assign labels to graphs\n",
    "train_graphs = ml.add_labels_to_graph(train_graphs, labels=split[\"train\"][\"Y\"].apply(torch.tensor), name=\"graph_label\")\n",
    "valid_graphs = ml.add_labels_to_graph(valid_graphs, labels=split[\"valid\"][\"Y\"].apply(torch.tensor), name=\"graph_label\")\n",
    "test_graphs = ml.add_labels_to_graph(test_graphs, labels=split[\"test\"][\"Y\"].apply(torch.tensor), name=\"graph_label\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "from graphein.ml import GraphFormatConvertor\n",
    "from torch_geometric.loader import DataLoader\n",
    "\n",
    "# Define a conversion object\n",
    "convertor = GraphFormatConvertor(\n",
    "    src_format=\"nx\",\n",
    "    dst_format=\"pyg\",\n",
    "    columns=[\"edge_index\", \"atom_type_one_hot\", \"graph_label\"]\n",
    "    )\n",
    "\n",
    "# Convert Graphs from NX to PyG\n",
    "train_graphs = [convertor(g) for g in train_graphs]\n",
    "valid_graphs = [convertor(g) for g in valid_graphs]\n",
    "test_graphs = [convertor(g) for g in test_graphs]\n",
    "\n",
    "# Create Dataloaders\n",
    "train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)\n",
    "valid_loader = DataLoader(valid_graphs, batch_size=32, shuffle=False)\n",
    "test_loader = DataLoader(test_graphs, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 546], node_id=[32], atom_type_one_hot=[529, 11], graph_label=[32], num_nodes=529, batch=[529], ptr=[33])\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_SKIP\n",
    "# Inspect a batch\n",
    "for i in train_loader:\n",
    "    print(i)\n",
    "    break"
   ]
  },
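  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch printed above reports `atom_type_one_hot=[529, 11]`: one row per atom in the batch and one column per atom type. The following is a minimal pure-Python sketch of that kind of encoding, not Graphein's implementation. The `ATOM_TYPES` list and the `one_hot` helper are assumptions made for illustration; the vocabulary Graphein actually uses may differ in content and order.\n",
    "\n",
    "```python\n",
    "# Hypothetical vocabulary of 11 atom symbols (an assumption for this sketch).\n",
    "ATOM_TYPES = ['C', 'H', 'O', 'N', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'B']\n",
    "\n",
    "\n",
    "def one_hot(symbol, vocabulary=ATOM_TYPES):\n",
    "    '''Return a list with a single 1 at the position of `symbol`.'''\n",
    "    if symbol not in vocabulary:\n",
    "        raise ValueError(f'unknown atom type: {symbol}')\n",
    "    return [1 if symbol == entry else 0 for entry in vocabulary]\n",
    "\n",
    "\n",
    "# Heavy atoms of tribromomethane, BrC(Br)Br, in SMILES order.\n",
    "atoms = ['Br', 'C', 'Br', 'Br']\n",
    "features = [one_hot(atom) for atom in atoms]\n",
    "```\n",
    "\n",
    "Each row of `features` contains exactly one 1, so stacking the rows for every atom in a batch gives a matrix with one column per atom type."
   ]
  },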
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "from torch_geometric.nn import GCNConv, global_add_pool\n",
    "from torch.nn.functional import mse_loss\n",
    "from torch.nn import functional as F\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "config_default = dict(\n",
    "    n_hid = 8,\n",
    "    n_out = 8,\n",
    "    batch_size = 4,\n",
    "    dropout = 0.5,\n",
    "    lr = 0.001,\n",
    "    num_heads = 32,\n",
    "    num_att_dim = 64,\n",
    "    model_name = 'GCN'\n",
    ")\n",
    "\n",
    "class Struct:\n",
    "    def __init__(self, **entries):\n",
    "        self.__dict__.update(entries)\n",
    "\n",
    "config = Struct(**config_default)\n",
    "\n",
    "global model_name\n",
    "model_name = config.model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "class GraphNets(pl.LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layer1 = GCNConv(in_channels=11, out_channels=config.n_hid)\n",
    "        self.layer2 = GCNConv(in_channels=config.n_hid, out_channels=config.n_out)\n",
    "        self.decoder = nn.Linear(config.n_out, 1)\n",
    "\n",
    "    def forward(self, g):\n",
    "        x = g.atom_type_one_hot.float()\n",
    "        x = F.dropout(x, p=config.dropout, training=self.training)\n",
    "        x = F.elu(self.layer1(x, g.edge_index))\n",
    "        x = F.dropout(x, p=config.dropout, training=self.training)\n",
    "        x = self.layer2(x, g.edge_index)\n",
    "        x = global_add_pool(x, batch=g.batch)\n",
    "        x = self.decoder(x)\n",
    "        return x\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x = batch\n",
    "        y = x.graph_label\n",
    "        y_hat = self(x)\n",
    "        loss = mse_loss(y_hat, y)\n",
    "\n",
    "        self.log(\"train_loss\", loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x = batch\n",
    "        y = x.graph_label\n",
    "        y_hat = self(x)\n",
    "        loss = mse_loss(y_hat, y)\n",
    "        self.log(\"valid_loss\", loss)\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        x = batch\n",
    "        y = x.graph_label\n",
    "        y_hat = self(x)\n",
    "        loss = mse_loss(y_hat, y)\n",
    "\n",
    "        self.log(\"test_loss\", loss)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.parameters(), lr=config.lr)"
   ]
  },
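  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `global_add_pool` call in `forward` turns per-atom embeddings into one vector per molecule by summing the rows that share a graph index in `g.batch`. A minimal pure-Python sketch of that reduction follows. `global_add_pool_sketch` is a hypothetical helper written for illustration, not the PyTorch Geometric implementation.\n",
    "\n",
    "```python\n",
    "def global_add_pool_sketch(node_features, batch, num_graphs):\n",
    "    '''Sum node feature vectors per graph, as indexed by `batch`.'''\n",
    "    dim = len(node_features[0])\n",
    "    pooled = [[0.0] * dim for _ in range(num_graphs)]\n",
    "    for row, graph_idx in zip(node_features, batch):\n",
    "        for j, value in enumerate(row):\n",
    "            pooled[graph_idx][j] += value\n",
    "    return pooled\n",
    "\n",
    "\n",
    "# Five nodes with 2-dimensional embeddings: three in graph 0, two in graph 1.\n",
    "node_features = [[1.0, 0.0], [2.0, 1.0], [0.5, 0.5], [1.0, 1.0], [3.0, 2.0]]\n",
    "batch = [0, 0, 0, 1, 1]\n",
    "pooled = global_add_pool_sketch(node_features, batch, num_graphs=2)\n",
    "```\n",
    "\n",
    "Here `pooled` has one row per graph, which is the shape the final linear decoder expects. Because the readout is a sum, larger molecules tend to produce larger pooled vectors."
   ]
  },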
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name    | Type    | Params\n",
      "------------------------------------\n",
      "0 | layer1  | GCNConv | 96    \n",
      "1 | layer2  | GCNConv | 72    \n",
      "2 | decoder | Linear  | 9     \n",
      "------------------------------------\n",
      "177       Trainable params\n",
      "0         Non-trainable params\n",
      "177       Total params\n",
      "0.001     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 49: 100%|██████████| 186/186 [00:01<00:00, 179.13it/s, loss=0.932, v_num=2]\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_SKIP\n",
    "model = GraphNets()\n",
    "\n",
    "trainer = pl.Trainer(max_epochs=50, gpus=1, strategy=None)\n",
    "trainer.fit(model, train_loader, valid_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████| 1477/1477 [00:03<00:00, 443.90it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8866580128669739     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8866580128669739    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[{'test_loss': 0.8866580128669739}]"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NBVAL_SKIP\n",
    "trainer.test(model, dataloaders=[test_loader])"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "0ab7f988027852efc1ebacd06db3f130eb65d2a20cb6a366311359132c20a952"
  },
  "kernelspec": {
   "display_name": "Python 3.8.13 ('graphein')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
